Gemma/[Gemma_2]RAG_LlamaIndex.ipynb

{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Tce3stUlHN0L" }, "source": [ "##### Copyright 2024 Google LLC." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "tuOe1ymfHZPu" }, "outputs": [], "source": [ "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "dfsDR_omdNea" }, "source": [ "# RAG with Gemma and LlamaIndex\n", "\n", "This notebook demonstrates how to integrate Gemma model with [LlamaIndex](https://www.llamaindex.ai/) library to build a basic RAG application.\n", "<table align=\"left\">\n", " <td>\n", " <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]RAG_LlamaIndex.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n", " </td>\n", "</table>" ] }, { "cell_type": "markdown", "metadata": { "id": "FaqZItBdeokU" }, "source": [ "## Setup\n", "\n", "### Select the Colab runtime\n", "To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use a T4 GPU:\n", "\n", "1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.\n", "2. Select **Change runtime type**.\n", "3. Under **Hardware accelerator**, select **T4 GPU**.\n", "\n", "### Gemma setup\n", "\n", "To complete this tutorial, you'll first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:\n", "\n", "* Get access to Gemma on kaggle.com.\n", "* Select a Colab runtime with sufficient resources to run\n", " the Gemma 2B model.\n", "* Generate and configure a Kaggle username and an API key as Colab secrets.\n", "\n", "After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "CY2kGtsyYpHF" }, "source": [ "### Configure your credentials\n", "\n", "Add your your Kaggle credentials to the Colab Secrets manager to securely store it.\n", "\n", "1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src=\"https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg\" alt=\"The Secrets tab is found on the left panel.\" width=50%>\n", "2. Create new secrets: `KAGGLE_USERNAME` and `KAGGLE_KEY`\n", "3. Copy/paste your username into `KAGGLE_USERNAME`\n", "3. Copy/paste your key into `KAGGLE_KEY`\n", "4. 
Toggle the buttons on the left to allow notebook access to the secrets.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "A9sUQ4WrP-Yr" }, "outputs": [], "source": [ "import os\n", "from google.colab import userdata\n", "\n", "# Set the backend before importing Keras\n", "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n", "# Avoid memory fragmentation on JAX backend.\n", "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.00\"\n", "\n", "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n", "# vars as appropriate for your system.\n", "os.environ[\"KAGGLE_USERNAME\"] = userdata.get(\"KAGGLE_USERNAME\")\n", "os.environ[\"KAGGLE_KEY\"] = userdata.get(\"KAGGLE_KEY\")" ] }, { "cell_type": "markdown", "metadata": { "id": "iwjo5_Uucxkw" }, "source": [ "### Install dependencies\n", "Run the cell below to install all the required dependencies." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "r_nXPEsF7UWQ" }, "outputs": [], "source": [ "!pip install -q -U tensorflow keras keras-nlp\n", "!pip install llama-index-embeddings-instructor\n", "!pip install sentence-transformers==2.2.2\n", "!pip install llama-index-readers-web llama-index-readers-file" ] }, { "cell_type": "markdown", "metadata": { "id": "J3sX2mFH4GWk" }, "source": [ "### Gemma" ] }, { "cell_type": "markdown", "metadata": { "id": "IgJWrlCC7v27" }, "source": [ "**About Gemma**\n", "\n", "Gemma is a family of lightweight, state-of-the-art open models from Google, built from the same research and technology used to create the Gemini models. They are text-to-text, decoder-only large language models, available in English, with open weights, pre-trained variants, and instruction-tuned variants. Gemma models are well-suited for a variety of text generation tasks, including question answering, summarization, and reasoning. Their relatively small size makes it possible to deploy them in environments with limited resources such as a laptop, desktop, or your own cloud infrastructure, democratizing access to state-of-the-art AI models and helping foster innovation for everyone." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "B3WckZv2hef3" }, "outputs": [], "source": [ "import keras\n", "import keras_nlp\n", "\n", "os.environ[\"KAGGLE_USERNAME\"] = userdata.get(\"KAGGLE_USERNAME\")\n", "os.environ[\"KAGGLE_KEY\"] = userdata.get(\"KAGGLE_KEY\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8dfseDZChhjl" }, "outputs": [], "source": [ "# Let's load Gemma using Keras\n", "gemma_model_id = \"gemma2_instruct_2b_en\"\n", "gemma = keras_nlp.models.GemmaCausalLM.from_preset(gemma_model_id)" ] }
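, { "cell_type": "markdown", "metadata": {}, "source": [ "Before moving on, you can optionally run a quick generation to confirm the model loaded correctly. The cell below is a minimal sanity check added for illustration (the prompt text is arbitrary); it wraps the prompt in the turn markers that Gemma's instruction-tuned variants expect." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional sanity check: make sure Gemma generates text before building the RAG pipeline.\n", "# The prompt below is only an illustrative example.\n", "test_prompt = (\n", "    \"<start_of_turn>user\\nIn one sentence, what is Gemma?<end_of_turn>\\n\"\n", "    \"<start_of_turn>model\\n\"\n", ")\n", "print(gemma.generate(test_prompt, max_length=128))" ] }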
, { "cell_type": "markdown", "metadata": { "id": "6xKBqqTD3zyq" }, "source": [ "## LlamaIndex" ] }, { "cell_type": "markdown", "metadata": { "id": "g_YESbFQ8DlU" }, "source": [ "LlamaIndex is a toolkit for developers to build applications that use large language models (LLMs) with specific data. This data can be private or related to a particular field. With LlamaIndex, developers can create various LLM applications, including question-answering chatbots, document analysis tools, and even autonomous agents. The toolkit offers functions to process data and design workflows that combine data retrieval with instructions for the LLM." ] }, { "cell_type": "markdown", "metadata": { "id": "vO2BH5iX8Lot" }, "source": [ "Large language models (LLMs) are powerful but lack your specific data. Retrieval-Augmented Generation (RAG) bridges this gap by incorporating your data for improved performance. RAG works by indexing your data for efficient retrieval based on user queries. The most relevant information, along with the query itself, is then fed to the LLM to generate a response. Understanding RAG is essential for building LLM applications like chatbots and agents." ] }, { "cell_type": "markdown", "metadata": { "id": "RZEp91Gb8pP9" }, "source": [ "### Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NJFTVCbP31P_" }, "outputs": [], "source": [ "from typing import Optional, List, Mapping, Any\n", "from llama_index.core.node_parser import SentenceSplitter\n", "from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader, PromptTemplate\n", "from llama_index.embeddings.instructor import InstructorEmbedding\n", "from llama_index.core.llms import (\n", "    CustomLLM,\n", "    CompletionResponse,\n", "    CompletionResponseGen,\n", "    LLMMetadata,\n", ")\n", "from llama_index.core.llms.callbacks import llm_completion_callback" ] }, { "cell_type": "markdown", "metadata": { "id": "s8jsT8dr8rjW" }, "source": [ "To ensure compatibility between Gemma and the LlamaIndex library, you need to create a simple interface class. The code below implements the basic generation methods, allowing the library to interact with the model effectively." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "njdY5e7f4IW-" }, "outputs": [], "source": [ "class GemmaLLMInterface(CustomLLM):\n", "    model: keras_nlp.models.GemmaCausalLM = None\n", "    context_window: int = 8192\n", "    num_output: int = 2048\n", "    model_name: str = \"gemma_2\"\n", "\n", "    def _format_prompt(self, message: str) -> str:\n", "        # Wrap the user message in Gemma's chat turn markers.\n", "        return (\n", "            f\"<start_of_turn>user\\n{message}<end_of_turn>\\n\" f\"<start_of_turn>model\\n\"\n", "        )\n", "\n", "    @property\n", "    def metadata(self) -> LLMMetadata:\n", "        \"\"\"Get LLM metadata.\"\"\"\n", "        return LLMMetadata(\n", "            context_window=self.context_window,\n", "            num_output=self.num_output,\n", "            model_name=self.model_name,\n", "        )\n", "\n", "    @llm_completion_callback()\n", "    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:\n", "        prompt = self._format_prompt(prompt)\n", "        raw_response = self.model.generate(prompt, max_length=self.num_output)\n", "        # Strip the echoed prompt so only the model's answer is returned.\n", "        response = raw_response[len(prompt) :]\n", "        return CompletionResponse(text=response)\n", "\n", "    @llm_completion_callback()\n", "    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:\n", "        # Generate the full answer once, then yield it incrementally.\n", "        full_response = self.complete(prompt).text\n", "        response = \"\"\n", "        for token in full_response:\n", "            response += token\n", "            yield CompletionResponse(text=response, delta=token)" ] }
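, { "cell_type": "markdown", "metadata": {}, "source": [ "As a quick, optional check (not part of the original flow), you can call the interface directly before registering it with LlamaIndex. This is a minimal sketch: it instantiates `GemmaLLMInterface` with the Gemma model loaded above and asks an arbitrary example question via `complete()`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional: exercise the custom LLM interface directly.\n", "# The question is an arbitrary example; any short prompt works here.\n", "test_llm = GemmaLLMInterface(model=gemma)\n", "print(test_llm.complete(\"In one sentence, what is Retrieval-Augmented Generation?\").text)" ] }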
"2_Dense/config.json: 0%| | 0.00/115 [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "73f01f30d8604a10ab2f39d696ddbfba", "version_major": 2, "version_minor": 0 }, "text/plain": [ "pytorch_model.bin: 0%| | 0.00/2.36M [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e9c40fb92b384c4abfa2afccdfd81d87", "version_major": 2, "version_minor": 0 }, "text/plain": [ "README.md: 0%| | 0.00/66.2k [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f785db0c4967428c9136a42836658166", "version_major": 2, "version_minor": 0 }, "text/plain": [ "config.json: 0%| | 0.00/1.55k [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "82d6085fe86c44cf99e72365ce3b07d1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "config_sentence_transformers.json: 0%| | 0.00/122 [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8f9bc5fe0ce34c6cad05371abf3a9c60", "version_major": 2, "version_minor": 0 }, "text/plain": [ "pytorch_model.bin: 0%| | 0.00/439M [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5e370509c6654bc6827a74cbcd1f7ef7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "sentence_bert_config.json: 0%| | 0.00/53.0 [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "59656b54c17940b19dc6f42e19e1a61b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "special_tokens_map.json: 0%| | 0.00/2.20k [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3b99ef33ce4c4aa2b2c030cb7fd082ef", "version_major": 2, "version_minor": 0 }, "text/plain": [ "spiece.model: 0%| | 0.00/792k [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7b041bf7e06f4886bb21c5d282c2a676", "version_major": 2, "version_minor": 0 }, "text/plain": [ "tokenizer.json: 0%| | 0.00/2.42M [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "62c6c45a64634d9a8ed47105e6bfb6cb", "version_major": 2, "version_minor": 0 }, "text/plain": [ "tokenizer_config.json: 0%| | 0.00/2.43k [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1b7a145d712e476cbdf3cb09dadb6bf1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "modules.json: 0%| | 0.00/461 [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "load INSTRUCTOR_Transformer\n", "max_seq_length 512\n" ] } ], "source": [ "# This settings define what models will be used by LlamaIndex\n", "Settings.embed_model = InstructorEmbedding(model_name=\"hkunlp/instructor-base\")\n", "Settings.llm = GemmaLLMInterface(model=gemma)" ] }, { "cell_type": "markdown", "metadata": { "id": "2t5WLak69hd8" }, "source": [ "## Retrieval-Augmented Generation 
(RAG)" ] }, { "cell_type": "markdown", "metadata": { "id": "A9iQpOQr91Uw" }, "source": [ "RAGs improve generative AI outputs through a multi-step process. First, it searches for relevant external data (webpages, databases) using powerful algorithms. This retrieved information is then cleaned and prepped for the LLM. Finally, the prepped data is fed alongside the original query into the LLM. This extra context allows the LLM to understand the topic better, resulting in more precise, informative, and engaging responses." ] }, { "cell_type": "markdown", "metadata": { "id": "Ulg3ptNf92I3" }, "source": [ "This notebook demonstrates how to build an RAG application by utilising Paul Graham's essay. The essay is used as a placeholder to illustrate the RAG concepts without introducing complexities of real-world data. It simplifies the process by focusing on a single, well-defined source (the essay) to showcase how data retrieval enhances LLM performance in a RAG system." ] }, { "cell_type": "markdown", "metadata": { "id": "j3tnN_qMAmbd" }, "source": [ "### Chunking the data" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ubRH-t77AJIx" }, "outputs": [], "source": [ "# Let's download the data frist\n", "!wget -q \"https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt\" \"paul_graham_essay.txt\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "53aGXIuor1Nv" }, "outputs": [], "source": [ "# Reading documents from disk\n", "documents = SimpleDirectoryReader(input_files=[\"paul_graham_essay.txt\"]).load_data()\n", "\n", "# Splitting the document into chunks with\n", "# predefined size and overlap\n", "parser = SentenceSplitter.from_defaults(\n", " chunk_size=256, chunk_overlap=64, paragraph_separator=\"\\n\\n\"\n", ")\n", "nodes = parser.get_nodes_from_documents(documents)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Fcr14ctm-msd" }, "outputs": [ { "data": { "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" }, "text/plain": [ "'What I Worked On\\n\\nFebruary 2021\\n\\nBefore college the two main things I worked on, outside of school, were writing and programming. I didn\\'t write essays. I wrote what beginning writers were supposed to write then, and probably still are: short stories. My stories were awful. They had hardly any plot, just characters with strong feelings, which I imagined made them deep.\\n\\nThe first programs I tried writing were on the IBM 1401 that our school district used for what was then called \"data processing.\" This was in 9th grade, so I was 13 or 14. The school district\\'s 1401 happened to be in the basement of our junior high school, and my friend Rich Draves and I got permission to use it. It was like a mini Bond villain\\'s lair down there, with all these alien-looking machines — CPU, disk drives, printer, card reader — sitting up on a raised floor under bright fluorescent lights.'" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Example node:\n", "nodes[0].text" ] }, { "cell_type": "markdown", "metadata": { "id": "Ox_vtnkRAp5X" }, "source": [ "### Building Vector Store" ] }, { "cell_type": "markdown", "metadata": { "id": "bswW6LMI-9Kc" }, "source": [ "Now you can now build a search engine that finds the best parts of text that answer a user's question." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BE-OyRUKr1wj" }, "outputs": [], "source": [ "# Converting the vector store to retrevier\n", "query_engine = VectorStoreIndex(nodes).as_query_engine(\n", " similarity_top_k=3, response_mode=\"tree_summarize\"\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "62pfQ9AZ_yI7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found: 3 relevant chunks\n", "1) She liked to paint on big, square canvases, 4 to 5 feet on a sid...\n", "2) The next year, from the summer of 1998 to the summer of 1999, mu...\n", "3) It was a lot of fun working with Robert and Trevor. They're the ...\n" ] } ], "source": [ "# Let's test it out\n", "relevant_chunks = query_engine.retrieve(\"1992\")\n", "print(f\"Found: {len(relevant_chunks)} relevant chunks\")\n", "for idx, chunk in enumerate(relevant_chunks):\n", " print(f\"{idx + 1}) {chunk.text[:64]}...\")" ] }, { "cell_type": "markdown", "metadata": { "id": "ImSNuY6GATgJ" }, "source": [ "Those chunks will be inject to the LLM's prompt in order to answer user query." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OX_fcTcshexQ" }, "outputs": [], "source": [ "# (Optional) Gemma works better with straightforward prompts without\n", "# additional tokens needed to separate sections.\n", "# You can simply update it to get better results:\n", "new_summary_tmpl_str = \"\"\"Text:\n", "{context_str}\n", "According to the text answer the query: {query_str}\"\"\"\n", "\n", "query_engine.update_prompts(\n", " {\"response_synthesizer:summary_template\": PromptTemplate(new_summary_tmpl_str)}\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "6Cr_g3JqA2sE" }, "source": [ "So when an user ask a question, the following prompt will be send to the LLM:\n", "\n", "```\n", "Text:\n", "<chunk #1>\n", "<chunk #2>\n", "<chunk #3>\n", "According to the text answer the query: <question>\n", "```\n", "\n", "By providing the large language model (LLM) with additional context, it can generate more accurate and informative responses to user queries.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "4iJxZAJi4qi_" }, "source": [ "### Test it yourself!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "R64H_BSPNPW9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ASP stands for **Application Service Provider**. \n", "<end_of_turn>\n" ] } ], "source": [ "response = query_engine.query(\"What does ASP stand for?\")\n", "print(response)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "J_YjGXFXiao0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The text mentions that Paul Graham applied to RISD (Rhode Island School of Design) and the Accademia. \n", "<end_of_turn>\n" ] } ], "source": [ "response = query_engine.query(\"What art schools did Paul Graham apply to?\")\n", "print(response)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "name": "[Gemma_2]RAG_LlamaIndex.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }